{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9f0759f2-5e46-438a-ad8e-b5d5771ec9ec",
   "metadata": {},
   "outputs": [],
   "source": [
    "# RAG-based Gradio app that answers questions from a collection of related documents, using Llama 3.2 and nomic-embed-text served by Ollama.\n",
    "# Built with help from Claude and the course material."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "448bd8f4-9181-4039-829f-d3f0a5f14171",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os, glob\n",
    "import sqlite3\n",
    "import json\n",
    "import numpy as np\n",
    "from typing import List, Dict, Tuple\n",
    "import requests\n",
    "import gradio as gr\n",
    "from datetime import datetime\n",
    "\n",
    "embedding_model = 'nomic-embed-text'\n",
    "llm_model = 'llama3.2'\n",
    "# Number of most-similar documents retrieved as context for each question.\n",
    "RagDist_k = 6\n",
    "# Sorted so documents load in the same order on every run (glob order depends on the filesystem).\n",
    "folders = sorted(glob.glob(\"../../week5/knowledge-base/*\"))\n",
    "folders"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dc085852-a80f-4f2c-b31a-80ceda10bec6",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "class OllamaEmbeddings:\n",
    "    \"\"\"Generate embeddings using Ollama's embedding models.\n",
    "\n",
    "    Args:\n",
    "        model: Ollama embedding model name (defaults to the notebook-level `embedding_model`).\n",
    "        base_url: Base URL of the Ollama server; a trailing slash is ignored.\n",
    "        timeout: Seconds `requests` waits on the connection before raising, so a\n",
    "            hung server cannot block the notebook forever.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, model: str = embedding_model, base_url: str = \"http://localhost:11434\",\n",
    "                 timeout: float = 120.0):\n",
    "        self.model = model\n",
    "        # Strip a trailing slash so the f\"{base_url}/api/...\" URLs never contain \"//\".\n",
    "        self.base_url = base_url.rstrip(\"/\")\n",
    "        self.timeout = timeout\n",
    "\n",
    "    def embed_text(self, text: str) -> List[float]:\n",
    "        \"\"\"Return the embedding vector Ollama computes for a single text.\n",
    "\n",
    "        Raises:\n",
    "            Exception: if Ollama answers with a non-200 status.\n",
    "            requests.exceptions.RequestException: on connection failure or timeout.\n",
    "        \"\"\"\n",
    "        print('Processing', text[:70].replace('\\n',' | '))\n",
    "        response = requests.post(\n",
    "            f\"{self.base_url}/api/embeddings\",\n",
    "            json={\"model\": self.model, \"prompt\": text},\n",
    "            timeout=self.timeout,\n",
    "        )\n",
    "        if response.status_code != 200:\n",
    "            raise Exception(f\"Error generating embedding: {response.text}\")\n",
    "        return response.json()[\"embedding\"]\n",
    "\n",
    "    def embed_documents(self, texts: List[str]) -> List[List[float]]:\n",
    "        \"\"\"Embed each text with one request per text; results keep the input order.\"\"\"\n",
    "        return [self.embed_text(text) for text in texts]\n",
    "\n",
    "\n",
    "class SQLiteVectorStore:\n",
    "    \"\"\"Vector store using SQLite for storing and retrieving document embeddings.\n",
    "\n",
    "    Embeddings and metadata are stored as JSON text; `similarity_search` scores\n",
    "    every stored row with cosine similarity (brute force, no index).\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, db_path: str = \"vector_store.db\"):\n",
    "        self.db_path = db_path\n",
    "        # check_same_thread=False allows this connection to be used from threads\n",
    "        # other than the one that created it (e.g. UI callback threads).\n",
    "        self.conn = sqlite3.connect(db_path, check_same_thread=False)\n",
    "        self._create_table()\n",
    "\n",
    "    def _create_table(self):\n",
    "        \"\"\"Create the documents table if it doesn't exist.\"\"\"\n",
    "        cursor = self.conn.cursor()\n",
    "        cursor.execute(\"\"\"\n",
    "            CREATE TABLE IF NOT EXISTS documents (\n",
    "                id INTEGER PRIMARY KEY AUTOINCREMENT,\n",
    "                content TEXT NOT NULL,\n",
    "                embedding TEXT NOT NULL,\n",
    "                metadata TEXT,\n",
    "                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP\n",
    "            )\n",
    "        \"\"\")\n",
    "        self.conn.commit()\n",
    "\n",
    "    def add_documents(self, texts: List[str], embeddings: List[List[float]],\n",
    "                     metadatas: List[Dict] = None):\n",
    "        \"\"\"Insert documents with their embeddings and optional metadata, committing once.\n",
    "\n",
    "        The three lists are paired positionally (``zip`` semantics, so extra\n",
    "        items beyond the shortest list are ignored).\n",
    "        \"\"\"\n",
    "        if metadatas is None:\n",
    "            # A fresh dict per row rather than N references to one shared dict.\n",
    "            metadatas = [{} for _ in texts]\n",
    "        rows = [\n",
    "            (text, json.dumps(embedding), json.dumps(metadata))\n",
    "            for text, embedding, metadata in zip(texts, embeddings, metadatas)\n",
    "        ]\n",
    "        self.conn.executemany(\n",
    "            \"INSERT INTO documents (content, embedding, metadata) VALUES (?, ?, ?)\",\n",
    "            rows,\n",
    "        )\n",
    "        self.conn.commit()\n",
    "\n",
    "    def cosine_similarity(self, vec1: np.ndarray, vec2: np.ndarray) -> float:\n",
    "        \"\"\"Calculate cosine similarity between two vectors.\n",
    "\n",
    "        Returns 0.0 when either vector has zero norm: dividing by zero would give\n",
    "        NaN, and NaN scores make the ranking in `similarity_search` unreliable.\n",
    "        \"\"\"\n",
    "        vec1 = np.asarray(vec1, dtype=float)\n",
    "        vec2 = np.asarray(vec2, dtype=float)\n",
    "        denominator = np.linalg.norm(vec1) * np.linalg.norm(vec2)\n",
    "        if denominator == 0:\n",
    "            return 0.0\n",
    "        return float(np.dot(vec1, vec2) / denominator)\n",
    "\n",
    "    def similarity_search(self, query_embedding: List[float], k: int = 3) -> List[Tuple[str, float, Dict]]:\n",
    "        \"\"\"Return up to k ``(content, similarity, metadata)`` tuples, most similar first.\n",
    "\n",
    "        Equal scores keep insertion order; ``k <= 0`` returns an empty list.\n",
    "        \"\"\"\n",
    "        if k <= 0:\n",
    "            return []\n",
    "        cursor = self.conn.cursor()\n",
    "        # ORDER BY id makes the scan order, and therefore tie-breaking, deterministic.\n",
    "        cursor.execute(\"SELECT content, embedding, metadata FROM documents ORDER BY id\")\n",
    "        rows = cursor.fetchall()\n",
    "\n",
    "        query_vec = np.asarray(query_embedding, dtype=float)\n",
    "        similarities = []\n",
    "        for content, embedding_json, metadata_json in rows:\n",
    "            doc_vec = np.asarray(json.loads(embedding_json), dtype=float)\n",
    "            similarity = self.cosine_similarity(query_vec, doc_vec)\n",
    "            # The metadata column is nullable; treat NULL as empty metadata.\n",
    "            metadata = json.loads(metadata_json) if metadata_json is not None else {}\n",
    "            similarities.append((content, similarity, metadata))\n",
    "\n",
    "        # list.sort is stable (also with reverse=True), so ties keep id order.\n",
    "        similarities.sort(key=lambda item: item[1], reverse=True)\n",
    "        return similarities[:k]\n",
    "\n",
    "    def clear_all(self):\n",
    "        \"\"\"Delete every document from the store.\"\"\"\n",
    "        cursor = self.conn.cursor()\n",
    "        cursor.execute(\"DELETE FROM documents\")\n",
    "        self.conn.commit()\n",
    "\n",
    "    def get_document_count(self) -> int:\n",
    "        \"\"\"Return the total number of documents in the store.\"\"\"\n",
    "        cursor = self.conn.cursor()\n",
    "        cursor.execute(\"SELECT COUNT(*) FROM documents\")\n",
    "        return cursor.fetchone()[0]\n",
    "\n",
    "\n",
    "class OllamaLLM:\n",
    "    \"\"\"Interact with an Ollama LLM for text generation.\n",
    "\n",
    "    Args:\n",
    "        model: Ollama model name (defaults to the notebook-level `llm_model`).\n",
    "        base_url: Base URL of the Ollama server; a trailing slash is ignored.\n",
    "        timeout: Seconds `requests` waits on the connection before raising.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, model: str = llm_model, base_url: str = \"http://localhost:11434\",\n",
    "                 timeout: float = 300.0):\n",
    "        self.model = model\n",
    "        self.base_url = base_url.rstrip(\"/\")\n",
    "        self.timeout = timeout\n",
    "\n",
    "    def generate(self, prompt: str, stream: bool = False) -> str:\n",
    "        \"\"\"Generate the complete response text for `prompt`.\n",
    "\n",
    "        With stream=True Ollama sends one JSON object per line instead of a single\n",
    "        JSON document, so `response.json()` cannot parse it; the chunks are read\n",
    "        line by line and their \"response\" fields concatenated. Either way the\n",
    "        full text is returned.\n",
    "\n",
    "        Raises:\n",
    "            Exception: if Ollama answers with a non-200 status.\n",
    "            requests.exceptions.RequestException: on connection failure or timeout.\n",
    "        \"\"\"\n",
    "        response = requests.post(\n",
    "            f\"{self.base_url}/api/generate\",\n",
    "            json={\"model\": self.model, \"prompt\": prompt, \"stream\": stream},\n",
    "            stream=stream,\n",
    "            timeout=self.timeout,\n",
    "        )\n",
    "        if response.status_code != 200:\n",
    "            raise Exception(f\"Error generating response: {response.text}\")\n",
    "        if not stream:\n",
    "            return response.json()[\"response\"]\n",
    "        chunks = (json.loads(line) for line in response.iter_lines() if line)\n",
    "        return \"\".join(chunk.get(\"response\", \"\") for chunk in chunks)\n",
    "\n",
    "\n",
    "class RAGSystem:\n",
    "    \"\"\"RAG system combining vector store, embeddings, and LLM.\"\"\"\n",
    "\n",
    "    def __init__(self, embedding_model: str = embedding_model,\n",
    "                 llm_model: str = llm_model,\n",
    "                 db_path: str = \"vector_store.db\"):\n",
    "        self.embeddings = OllamaEmbeddings(model=embedding_model)\n",
    "        self.vector_store = SQLiteVectorStore(db_path=db_path)\n",
    "        self.llm = OllamaLLM(model=llm_model)\n",
    "\n",
    "    def add_documents(self, documents: List[Dict[str, str]]):\n",
    "        \"\"\"\n",
    "        Embed and store documents in the RAG system.\n",
    "        documents: List of dicts with 'content' and optional 'metadata'\n",
    "        \"\"\"\n",
    "        if not documents:\n",
    "            print(\"No documents to add.\")\n",
    "            return\n",
    "        texts = [doc['content'] for doc in documents]\n",
    "        metadatas = [doc.get('metadata', {}) for doc in documents]\n",
    "\n",
    "        print(f\"Generating embeddings for {len(texts)} documents...\")\n",
    "        embeddings = self.embeddings.embed_documents(texts)\n",
    "\n",
    "        print(\"Storing documents in vector store...\")\n",
    "        self.vector_store.add_documents(texts, embeddings, metadatas)\n",
    "        print(f\"Successfully added {len(texts)} documents!\")\n",
    "\n",
    "    def query(self, question: str, k: int = 3) -> str:\n",
    "        \"\"\"Answer `question` using the k most similar stored documents as context.\"\"\"\n",
    "        # Nothing to embed for a blank message; skip the round trip to Ollama.\n",
    "        if not question or not question.strip():\n",
    "            return \"Please enter a question.\"\n",
    "\n",
    "        query_embedding = self.embeddings.embed_text(question)\n",
    "        results = self.vector_store.similarity_search(query_embedding, k=k)\n",
    "\n",
    "        if not results:\n",
    "            return \"I don't have any information to answer this question.\"\n",
    "\n",
    "        context = \"\\n\\n\".join(\n",
    "            f\"Document {i+1} (Relevance: {score:.2f}):\\n{content}\"\n",
    "            for i, (content, score, _) in enumerate(results)\n",
    "        )\n",
    "\n",
    "        # Built from separate literals so the source indentation does not leak\n",
    "        # into the prompt as leading spaces on every line.\n",
    "        prompt = (\n",
    "            \"You are a helpful assistant answering questions based on the provided context.\\n\"\n",
    "            \"Use the following context to answer the question. \"\n",
    "            \"If you cannot answer the question based on the context, say so.\\n\"\n",
    "            \"\\n\"\n",
    "            f\"Context:\\n{context}\\n\"\n",
    "            \"\\n\"\n",
    "            f\"Question: {question}\\n\"\n",
    "            \"\\n\"\n",
    "            \"Answer:\"\n",
    "        )\n",
    "\n",
    "        return self.llm.generate(prompt)\n",
    "\n",
    "    def get_stats(self) -> str:\n",
    "        \"\"\"Get statistics about the RAG system.\"\"\"\n",
    "        doc_count = self.vector_store.get_document_count()\n",
    "        return f\"Total documents in database: {doc_count}\"\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "37cbaa24-6e17-4712-8c90-429264b9b82e",
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_documents(folder_paths: List[str] = None) -> List[Dict[str, str]]:\n",
    "    \"\"\"\n",
    "    Read the Markdown files in each folder and format them for the RAG system.\n",
    "\n",
    "    Args:\n",
    "        folder_paths: Folder paths to read files from. When omitted, the\n",
    "            notebook-level `folders` list is used (the original behaviour).\n",
    "    Returns:\n",
    "        List of dictionaries with 'content' and 'metadata' keys; 'metadata' holds\n",
    "        the folder name ('type'), the file name ('name') and the content length\n",
    "        in characters ('datalen').\n",
    "    \"\"\"\n",
    "    from pathlib import Path\n",
    "\n",
    "    if folder_paths is None:\n",
    "        folder_paths = folders\n",
    "\n",
    "    documents = []\n",
    "    supported_extensions = {'.md'}\n",
    "\n",
    "    for folder in folder_paths:\n",
    "        folder_path = Path(folder)\n",
    "\n",
    "        if not folder_path.exists():\n",
    "            print(f\"Warning: Folder '{folder}' does not exist. Skipping...\")\n",
    "            continue\n",
    "\n",
    "        if not folder_path.is_dir():\n",
    "            print(f\"Warning: '{folder}' is not a directory. Skipping...\")\n",
    "            continue\n",
    "\n",
    "        folder_name = folder_path.name\n",
    "\n",
    "        # Sorted so the load order (and hence DB insertion order) is reproducible.\n",
    "        files = sorted(f for f in folder_path.iterdir() if f.is_file())\n",
    "\n",
    "        for file_path in files:\n",
    "            if file_path.suffix.lower() not in supported_extensions:\n",
    "                print(f\"Skipping unsupported file type: {file_path.name}\")\n",
    "                continue\n",
    "\n",
    "            try:\n",
    "                with open(file_path, 'r', encoding='utf-8') as f:\n",
    "                    content = f.read()\n",
    "            except (OSError, UnicodeDecodeError) as e:\n",
    "                # Best effort: report the unreadable file and keep loading the rest.\n",
    "                print(f\"Error reading file {file_path.name}: {str(e)}\")\n",
    "                continue\n",
    "\n",
    "            # A blank file carries no information; embedding it only adds noise.\n",
    "            if not content.strip():\n",
    "                print(f\"Skipping empty file: {file_path.name}\")\n",
    "                continue\n",
    "\n",
    "            documents.append({\n",
    "                'metadata': {\n",
    "                    'type': folder_name,\n",
    "                    'name': file_path.name,\n",
    "                    'datalen': len(content)\n",
    "                },\n",
    "                'content': content,\n",
    "            })\n",
    "            print(f\"✓ Loaded: {file_path.name} from folder '{folder_name}'\")\n",
    "\n",
    "    print(f\"\\nTotal documents loaded: {len(documents)}\")\n",
    "    return documents\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d257bd84-fd7b-4a64-bc5b-148b30b00aa3",
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_gradio_interface(rag_system: RAGSystem):\n",
    "    \"\"\"Create the Gradio chat interface for the RAG system.\n",
    "\n",
    "    Returns the (not yet launched) `gr.Blocks` app: a chat panel on the left and\n",
    "    load/statistics controls on the right.\n",
    "    \"\"\"\n",
    "\n",
    "    def chat_fn(message, history):\n",
    "        \"\"\"Answer one chat message; `history` is ignored, so each question stands alone.\"\"\"\n",
    "        try:\n",
    "            return rag_system.query(message, k=RagDist_k)\n",
    "        except Exception as e:\n",
    "            return f\"Error: {str(e)}\\n\\nMake sure Ollama is running with the required models installed.\"\n",
    "\n",
    "    def load_data():\n",
    "        \"\"\"Rebuild the vector store from the knowledge-base documents.\n",
    "\n",
    "        The SQLite file persists across runs and clicks, so existing rows are\n",
    "        cleared first; appending again would store duplicate copies that crowd\n",
    "        the top-k retrieval results.\n",
    "        \"\"\"\n",
    "        try:\n",
    "            documents = load_documents()\n",
    "            if not documents:\n",
    "                return f\"⚠️ No documents found to load.\\n{rag_system.get_stats()}\"\n",
    "            rag_system.vector_store.clear_all()\n",
    "            rag_system.add_documents(documents)\n",
    "            return f\"✅ Documents loaded successfully!\\n{rag_system.get_stats()}\"\n",
    "        except Exception as e:\n",
    "            return f\"❌ Error loading documents: {str(e)}\"\n",
    "\n",
    "    with gr.Blocks(title=\"RAG System - Company Knowledge Base\", theme=gr.themes.Soft()) as demo:\n",
    "        gr.Markdown(\"# 🤖 RAG System - Company Knowledge Base\")\n",
    "        gr.Markdown(\"Ask questions about company information, contracts, employees, and products.\")\n",
    "\n",
    "        with gr.Row():\n",
    "            with gr.Column(scale=3):\n",
    "                gr.ChatInterface(\n",
    "                    fn=chat_fn,\n",
    "                    examples=[\n",
    "                        \"Who is the CTO of the company?\",\n",
    "                        \"Who is the CEO of the company?\",\n",
    "                        \"What products does the company offer?\",\n",
    "                    ],\n",
    "                    title=\"\",\n",
    "                    description=\"💬 Chat with the company knowledge base\"\n",
    "                )\n",
    "\n",
    "            with gr.Column(scale=1):\n",
    "                gr.Markdown(\"### 📊 System Controls\")\n",
    "                load_btn = gr.Button(\"📥 Load Documents\", variant=\"primary\")\n",
    "                stats_btn = gr.Button(\"📈 Get Statistics\")\n",
    "                output_box = gr.Textbox(label=\"System Output\", lines=5)\n",
    "\n",
    "                load_btn.click(fn=load_data, outputs=output_box)\n",
    "                stats_btn.click(fn=rag_system.get_stats, outputs=output_box)\n",
    "\n",
    "                gr.Markdown(f\"\"\"\n",
    "                ### 📝 Instructions:\n",
    "                1. Make sure Ollama is running\n",
    "                2. Click \"Load Documents\" (clicking again rebuilds the index)\n",
    "                3. Start asking questions!\n",
    "                \n",
    "                ### 🔧 Required Models:\n",
    "                - `ollama pull {embedding_model}`\n",
    "                - `ollama pull {llm_model}`\n",
    "                \"\"\")\n",
    "\n",
    "    return demo\n",
    "\n",
    "\n",
    "def main():\n",
    "    \"\"\"Build the RAG system, print setup hints, then launch the Gradio UI.\"\"\"\n",
    "    separator = \"=\" * 60\n",
    "    for banner_line in (separator, \"RAG System with Ollama and SQLite\", separator):\n",
    "        print(banner_line)\n",
    "\n",
    "    print(\"\\nInitializing RAG system...\")\n",
    "    system = RAGSystem(\n",
    "        embedding_model=embedding_model,\n",
    "        llm_model=llm_model,\n",
    "        db_path=\"vector_store.db\",\n",
    "    )\n",
    "\n",
    "    print(\"\\n⚠️  Make sure Ollama is running and you have the required models:\")\n",
    "    for required_model in (embedding_model, llm_model):\n",
    "        print(f\"   - ollama pull {required_model}\")\n",
    "    print(\"\\nStarting Gradio interface...\")\n",
    "\n",
    "    app = create_gradio_interface(system)\n",
    "    app.launch(share=False)\n",
    "\n",
    "\n",
    "main()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "01b4ff0e-36a5-43b5-8ecf-59e42a18a908",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
